package com.adu.music.db;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.CannotGetJdbcConnectionException;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.DefaultTransactionDefinition;

import javax.sql.DataSource;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * A {@link NamedParameterJdbcTemplate} that buffers {@link #update(String, Map)} calls and writes them to the
 * database in batches.
 * <p>
 * Parameter maps are grouped by SQL string. A group is written with {@code batchUpdate} inside a transaction when
 * either it reaches {@code LIMITED} (500) rows, synchronously on the calling thread, or the background
 * "auto-flush-bytime-thread" fires, roughly every {@code FLUSH_INTERVAL_SECONDS} (60) seconds while the template
 * is open. Because writes are deferred, {@code update} always returns {@code 1}. Database errors raised while
 * flushing are only logged; the affected batch is dropped and never reported to the caller.
 * <p>
 * Only the {@code update(String, Map)} overload is overridden here; other inherited methods are untouched.
 *
 * @author duchuanchuan
 * @date 2016/12/3.
 */
public class AsyncNamedParameterJdbcTemplate extends NamedParameterJdbcTemplate {

    private static final Logger LOGGER = LoggerFactory.getLogger(AsyncNamedParameterJdbcTemplate.class);

    /**
     * Buffered parameter maps waiting to be written, keyed by the SQL statement they belong to.
     */
    private final ConcurrentMap<String, List<Map<String, ?>>> datasMap = new ConcurrentHashMap<>();

    /**
     * One monitor per SQL statement, guarding that statement's entry in {@link #datasMap}.
     */
    private final ConcurrentMap<String, Object> lockMap = new ConcurrentHashMap<>();

    /**
     * A SQL statement's buffer is flushed as soon as it holds at least this many rows.
     */
    private static final int LIMITED = 500;

    /**
     * Number of one-second ticks between two time-triggered flushes of all buffers.
     */
    private static final int FLUSH_INTERVAL_SECONDS = 60;

    private volatile boolean closeFlag = false;

    /**
     * Seconds elapsed since the last flush; reset by both time-triggered and count-triggered flushes.
     */
    private final AtomicInteger timeCount = new AtomicInteger();

    /**
     * Guards open/close transitions and replacement of {@link #flushThread}.
     */
    private final Object lifecycleLock = new Object();

    /**
     * The currently active background flush thread, or {@code null} when none is running. A worker keeps looping
     * only while it is still the value of this field, so clearing or replacing it retires the old worker.
     */
    private volatile Thread flushThread;

    private final DataSourceTransactionManager transactionManager;
    private final DefaultTransactionDefinition def;

    /**
     * Creates the template and immediately starts the background flush thread.
     *
     * @param dataSource the data source used both for JDBC access and for the batch transactions
     */
    public AsyncNamedParameterJdbcTemplate(DataSource dataSource) {
        super(dataSource);
        transactionManager = new DataSourceTransactionManager(dataSource);
        def = new DefaultTransactionDefinition();
        def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRED);
        autoFlush();
    }

    /**
     * Starts a new background worker that flushes every buffer roughly every {@code FLUSH_INTERVAL_SECONDS}
     * seconds, and makes it the active flush thread. Any previous worker exits on its next tick.
     */
    private void autoFlush() {
        synchronized (lifecycleLock) {
            Thread worker = new Thread(this::runAutoFlush, "auto-flush-bytime-thread");
            // Publish before start() so the worker sees itself as the active thread.
            flushThread = worker;
            worker.start();
        }
    }

    /**
     * Body of the background worker: ticks once per second and flushes all buffers every
     * {@code FLUSH_INTERVAL_SECONDS} ticks, until the template is closed or this worker is replaced.
     */
    private void runAutoFlush() {
        Thread self = Thread.currentThread();
        try {
            while (!closeFlag && flushThread == self) {
                if (timeCount.incrementAndGet() >= FLUSH_INTERVAL_SECONDS) {
                    timeCount.set(0);
                    flushAll(); // batch-write the buffers of every SQL statement
                }
                TimeUnit.SECONDS.sleep(1);
            }
        } catch (InterruptedException e) {
            LOGGER.error("auto-flush-time-thread sleep error!", e);
            // Preserve the interrupt status for anything further up this thread's stack.
            Thread.currentThread().interrupt();
        } finally {
            synchronized (lifecycleLock) {
                // Only clear the slot if no newer worker has taken over, so open() can restart us.
                if (flushThread == self) {
                    flushThread = null;
                }
            }
        }
    }

    /**
     * Buffers the statement instead of executing it.
     * <p>
     * The parameter map is appended to the buffer for {@code sql}. Once that buffer holds {@code LIMITED} (500)
     * rows, it is flushed on the calling thread. The flush runs while this statement's lock is held, so other
     * callers buffering the same SQL wait for the database write to finish.
     *
     * @param sql      the named-parameter SQL statement
     * @param paramMap the parameter values for one execution; stored by reference, not copied
     * @return always {@code 1}, because the row has only been queued
     * @throws DataAccessException (specifically {@link CannotGetJdbcConnectionException}) if the template has
     *                             been closed
     */
    @Override
    public int update(String sql, Map<String, ?> paramMap) throws DataAccessException {
        if (closeFlag) {
            throw new CannotGetJdbcConnectionException(//
                    "ConcurrentNamedParameterJdbcTemplate已经被关闭", //
                    new SQLException("ConcurrentNamedParameterJdbcTemplate已经被关闭"));
        }

        synchronized (getSqlLock(sql)) {
            List<Map<String, ?>> pending = datasMap.computeIfAbsent(sql, k -> new ArrayList<>());
            pending.add(paramMap);
            if (pending.size() >= LIMITED) {
                // A count-triggered flush also restarts the time-based countdown.
                timeCount.set(0);
                flush(sql, "count", System.currentTimeMillis());
            }
        }

        return 1;
    }

    /**
     * Returns the monitor guarding the buffer of one SQL statement, creating it on first use.
     *
     * @param sql sql statement
     * @return the lock object associated with that statement; the same instance on every call
     */
    private Object getSqlLock(String sql) {
        Object lock = lockMap.get(sql);
        // Fast path avoids the bin locking computeIfAbsent may do even when the key is present.
        return lock != null ? lock : lockMap.computeIfAbsent(sql, k -> new Object());
    }

    /**
     * Flushes the buffer of every SQL statement seen so far.
     */
    private void flushAll() {
        for (String sql : datasMap.keySet()) {
            flush(sql, "time", System.currentTimeMillis());
        }
    }

    /**
     * Detaches the buffer of {@code sql} and writes it as one batch.
     * <p>
     * The buffer is swapped for an empty one under the statement's lock, so the lock is held as briefly as
     * possible here. The database write happens outside it. Any failure is logged, and the detached rows are
     * dropped.
     *
     * @param sql   the SQL statement whose buffer should be written
     * @param src   label for the log line ("count" or "time")
     * @param start timestamp in milliseconds used to report the elapsed time
     */
    private void flush(String sql, String src, long start) {
        List<Map<String, ?>> dataOfSql;
        synchronized (getSqlLock(sql)) {
            // put() returns the previous buffer, which we now own exclusively.
            dataOfSql = datasMap.put(sql, new ArrayList<>());
        }
        if (dataOfSql == null || dataOfSql.isEmpty()) {
            return;
        }
        try {
            executeBatchInTransaction(sql, dataOfSql);
            LOGGER.info("batchUpdate_sql_{},size={},num={},sql={},allTime={}", src, datasMap.size(),
                    dataOfSql.size(), (sql.length() > 50 ? sql.substring(0, 50) : sql),
                    (System.currentTimeMillis() - start));
        } catch (Throwable t) {
            // Best effort: the batch is lost, but the caller or flush thread must keep running.
            LOGGER.error("批量插入时发生异常", t);
        }
    }

    /**
     * Runs {@code batchUpdate} for the given rows inside a transaction obtained from {@link #transactionManager}.
     * <p>
     * If the batch fails, the transaction is rolled back before the failure is rethrown. Otherwise the transaction
     * would stay bound to the current thread, and later flushes on this thread would silently join it.
     *
     * @param sql   the SQL statement
     * @param batch one parameter map per row
     */
    @SuppressWarnings("unchecked")
    private void executeBatchInTransaction(String sql, List<Map<String, ?>> batch) {
        TransactionStatus status = transactionManager.getTransaction(def);
        try {
            batchUpdate(sql, batch.toArray(new Map[0]));
        } catch (RuntimeException | Error e) {
            try {
                transactionManager.rollback(status);
            } catch (RuntimeException rollbackFailure) {
                e.addSuppressed(rollbackFailure);
            }
            throw e;
        }
        // commit() performs its own cleanup if it fails, so it must not be followed by rollback().
        transactionManager.commit(status);
    }

    /**
     * Closes the template. Subsequent {@link #update(String, Map)} calls are rejected, the background flush
     * thread is retired, and all buffered rows are flushed once more.
     */
    public void close() {
        synchronized (lifecycleLock) {
            closeFlag = true;
            flushThread = null; // the current worker exits on its next tick
        }
        flushAll(); // one final flush after closing
    }

    /**
     * Re-opens the template so that {@link #update(String, Map)} is accepted again, restarting the background
     * flush thread if none is running.
     */
    public void open() {
        synchronized (lifecycleLock) {
            closeFlag = false;
            if (flushThread == null) {
                autoFlush();
            }
        }
    }
}
